e618bd
@@ -104,20 +104,7 @@
 import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
 import org.apache.hadoop.hive.ql.plan.GroupByDesc;
 import org.apache.hadoop.hive.ql.udf.SettableUDF;
-import org.apache.hadoop.hive.ql.udf.UDFConv;
-import org.apache.hadoop.hive.ql.udf.UDFFromUnixTime;
-import org.apache.hadoop.hive.ql.udf.UDFHex;
-import org.apache.hadoop.hive.ql.udf.UDFRegExpExtract;
-import org.apache.hadoop.hive.ql.udf.UDFRegExpReplace;
-import org.apache.hadoop.hive.ql.udf.UDFSign;
-import org.apache.hadoop.hive.ql.udf.UDFToBoolean;
-import org.apache.hadoop.hive.ql.udf.UDFToByte;
-import org.apache.hadoop.hive.ql.udf.UDFToDouble;
-import org.apache.hadoop.hive.ql.udf.UDFToFloat;
-import org.apache.hadoop.hive.ql.udf.UDFToInteger;
-import org.apache.hadoop.hive.ql.udf.UDFToLong;
-import org.apache.hadoop.hive.ql.udf.UDFToShort;
-import org.apache.hadoop.hive.ql.udf.UDFToString;
+import org.apache.hadoop.hive.ql.udf.*;
 import org.apache.hadoop.hive.ql.udf.generic.*;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode;
 import org.apache.hadoop.hive.serde2.ByteStream.Output;
@@ -359,6 +346,67 @@
public void addProjectionColumn(String columnName, int vectorBatchColIndex) {
     castExpressionUdfs.add(UDFToShort.class);
   }
 
+  // Set of UDF classes which need implicit type casting of decimal parameters.
+  // Vectorization for mathematical functions currently depends on decimal params automatically
+  // being converted to the return type (see getImplicitCastExpression()), which is not correct
+  // in the general case. This set restricts automatic type conversion to just these functions.
+  private static final Set<Class<?>> udfsNeedingImplicitDecimalCast = new HashSet<Class<?>>();
+  static {
+    udfsNeedingImplicitDecimalCast.add(GenericUDFOPPlus.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFOPMinus.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFOPMultiply.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFOPDivide.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFOPMod.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFRound.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFBRound.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFFloor.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFCbrt.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFCeil.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFAbs.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFPosMod.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFPower.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFFactorial.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFOPPositive.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFOPNegative.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFCoalesce.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFElt.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFGreatest.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFLeast.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFIn.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFOPEqual.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFOPEqualNS.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFOPNotEqual.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFOPLessThan.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFOPEqualOrLessThan.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFOPGreaterThan.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFOPEqualOrGreaterThan.class);
+    udfsNeedingImplicitDecimalCast.add(GenericUDFBetween.class);
+    udfsNeedingImplicitDecimalCast.add(UDFSqrt.class);
+    udfsNeedingImplicitDecimalCast.add(UDFRand.class);
+    udfsNeedingImplicitDecimalCast.add(UDFLn.class);
+    udfsNeedingImplicitDecimalCast.add(UDFLog2.class);
+    udfsNeedingImplicitDecimalCast.add(UDFSin.class);
+    udfsNeedingImplicitDecimalCast.add(UDFAsin.class);
+    udfsNeedingImplicitDecimalCast.add(UDFCos.class);
+    udfsNeedingImplicitDecimalCast.add(UDFAcos.class);
+    udfsNeedingImplicitDecimalCast.add(UDFLog10.class);
+    udfsNeedingImplicitDecimalCast.add(UDFLog.class);
+    udfsNeedingImplicitDecimalCast.add(UDFExp.class);
+    udfsNeedingImplicitDecimalCast.add(UDFDegrees.class);
+    udfsNeedingImplicitDecimalCast.add(UDFRadians.class);
+    udfsNeedingImplicitDecimalCast.add(UDFAtan.class);
+    udfsNeedingImplicitDecimalCast.add(UDFTan.class);
+    udfsNeedingImplicitDecimalCast.add(UDFOPLongDivide.class);
+  }
+
+  protected boolean needsImplicitCastForDecimal(GenericUDF udf) {
+    // A GenericUDFBridge is looked up by the class from getUdfClass(), not by the bridge class.
+    final Class<?> lookupClass = (udf instanceof GenericUDFBridge)
+        ? ((GenericUDFBridge) udf).getUdfClass()
+        : udf.getClass();
+    return udfsNeedingImplicitDecimalCast.contains(lookupClass);
+  }
+
   protected int getInputColumnIndex(String name) throws HiveException {
     if (name == null) {
       throw new HiveException("Null column name");
@@ -764,24 +812,26 @@
private ExprNodeDesc getImplicitCastExpression(GenericUDF udf, ExprNodeDesc chil
     }
 
     if (castTypeDecimal && !inputTypeDecimal) {
-
-      // Cast the input to decimal
-      // If castType is decimal, try not to lose precision for numeric types.
-      castType = updatePrecision(inputTypeInfo, (DecimalTypeInfo) castType);
-      GenericUDFToDecimal castToDecimalUDF = new GenericUDFToDecimal();
-      castToDecimalUDF.setTypeInfo(castType);
-      List<ExprNodeDesc> children = new ArrayList<ExprNodeDesc>();
-      children.add(child);
-      ExprNodeDesc desc = new ExprNodeGenericFuncDesc(castType, castToDecimalUDF, children);
-      return desc;
+      if (needsImplicitCastForDecimal(udf)) {
+        // Cast the input to decimal
+        // If castType is decimal, try not to lose precision for numeric types.
+        castType = updatePrecision(inputTypeInfo, (DecimalTypeInfo) castType);
+        GenericUDFToDecimal castToDecimalUDF = new GenericUDFToDecimal();
+        castToDecimalUDF.setTypeInfo(castType);
+        List<ExprNodeDesc> children = new ArrayList<ExprNodeDesc>();
+        children.add(child);
+        ExprNodeDesc desc = new ExprNodeGenericFuncDesc(castType, castToDecimalUDF, children);
+        return desc;
+      }
     } else if (!castTypeDecimal && inputTypeDecimal) {
-
-      // Cast decimal input to returnType
-      GenericUDF genericUdf = getGenericUDFForCast(castType);
-      List<ExprNodeDesc> children = new ArrayList<ExprNodeDesc>();
-      children.add(child);
-      ExprNodeDesc desc = new ExprNodeGenericFuncDesc(castType, genericUdf, children);
-      return desc;
+      if (needsImplicitCastForDecimal(udf)) {
+        // Cast decimal input to returnType
+        GenericUDF genericUdf = getGenericUDFForCast(castType);
+        List<ExprNodeDesc> children = new ArrayList<ExprNodeDesc>();
+        children.add(child);
+        ExprNodeDesc desc = new ExprNodeGenericFuncDesc(castType, genericUdf, children);
+        return desc;
+      }
     } else {
 
       // Casts to exact types including long to double etc. are needed in some special cases.
